|
51 | 51 |
|
52 | 52 | logger = get_logger(__name__, log_level="INFO")
|
53 | 53 |
|
| 54 | +DATASET_NAME_MAPPING = { |
| 55 | + "lambdalabs/pokemon-blip-captions": ("image", "text"), |
| 56 | +} |
| 57 | + |
54 | 58 |
|
55 | 59 | def parse_args():
|
56 | 60 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
@@ -193,6 +197,13 @@ def parse_args():
|
193 | 197 | parser.add_argument(
|
194 | 198 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
195 | 199 | )
|
| 200 | + parser.add_argument( |
| 201 | + "--_snr_gamma", |
| 202 | + type=float, |
| 203 | + default=None, |
| 204 | + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " |
| 205 | + "More details here: https://arxiv.org/abs/2303.09556.", |
| 206 | + ) |
196 | 207 | parser.add_argument(
|
197 | 208 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
198 | 209 | )
|
@@ -325,9 +336,32 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
|
325 | 336 | return f"{organization}/{model_id}"
|
326 | 337 |
|
327 | 338 |
|
328 |
| -dataset_name_mapping = { |
329 |
| - "lambdalabs/pokemon-blip-captions": ("image", "text"), |
330 |
| -} |
| 339 | +def expand_tensor(arr, timesteps, broadcast_shape): |
| 340 | + """ |
| 341 | + Extract values from a 1-D numpy array for a batch of indices. |
| 342 | + Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 |
| 343 | + """ |
| 344 | + res = arr.to(device=timesteps.device)[timesteps].float() |
| 345 | + while len(res.shape) < len(broadcast_shape): |
| 346 | + res = res[..., None] |
| 347 | + return res.expand(broadcast_shape) |
| 348 | + |
| 349 | + |
| 350 | +def compute_snr(noise_scheduler): |
| 351 | + """ |
| 352 | + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 |
| 353 | + """ |
| 354 | + alphas_cumprod = noise_scheduler.alphas_cumprod |
| 355 | + sqrt_alphas_cumprod = alphas_cumprod**0.5 |
| 356 | + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 |
| 357 | + |
| 358 | + def fn(timesteps): |
| 359 | + alpha = expand_tensor(sqrt_alphas_cumprod, timesteps, timesteps.shape) |
| 360 | + sigma = expand_tensor(sqrt_one_minus_alphas_cumprod, timesteps, timesteps.shape) |
| 361 | + snr = (alpha / sigma) ** 2 |
| 362 | + return snr |
| 363 | + |
| 364 | + return fn |
331 | 365 |
|
332 | 366 |
|
333 | 367 | def main():
|
@@ -476,6 +510,9 @@ def load_model_hook(models, input_dir):
|
476 | 510 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
477 | 511 | )
|
478 | 512 |
|
| 513 | + if args.snr_gamma is not None: |
| 514 | + snr_fn = compute_snr(noise_scheduler) |
| 515 | + |
479 | 516 | # Initialize the optimizer
|
480 | 517 | if args.use_8bit_adam:
|
481 | 518 | try:
|
@@ -526,7 +563,7 @@ def load_model_hook(models, input_dir):
|
526 | 563 | column_names = dataset["train"].column_names
|
527 | 564 |
|
528 | 565 | # 6. Get the column names for input/target.
|
529 |
| - dataset_columns = dataset_name_mapping.get(args.dataset_name, None) |
| 566 | + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) |
530 | 567 | if args.image_column is None:
|
531 | 568 | image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
|
532 | 569 | else:
|
@@ -734,7 +771,23 @@ def collate_fn(examples):
|
734 | 771 |
|
735 | 772 | # Predict the noise residual and compute loss
|
736 | 773 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
737 |
| - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
| 774 | + |
| 775 | + if args.snr_gamma is None: |
| 776 | + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
| 777 | + else: |
| 778 | + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. |
| 779 | + # Since we predict the noise instead of x_0, the original formulation is slightly changed. |
| 780 | + # This is discussed in Section 4.2 of the same paper. |
| 781 | + snr = snr_fn(timesteps) |
| 782 | + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( |
| 783 | + dim=1 |
| 784 | + )[0] |
| 785 | + # We first calculate the original loss. Then we mean over the non-batch dimensions and |
| 786 | + # rebalance the sample-wise losses with their respective loss weights. |
| 787 | + # Finally, we take the mean of the rebalanced loss. |
| 788 | + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") |
| 789 | + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights |
| 790 | + loss = (mse_loss_weights * loss).mean() |
738 | 791 |
|
739 | 792 | # Gather the losses across all processes for logging (if we use distributed training).
|
740 | 793 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
|
|
0 commit comments